-
Notifications
You must be signed in to change notification settings - Fork 223
Feat/add lora for sglangjax #826
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
32a06fd to
3af36c3
Compare
3ea01a4 to
a56ef3a
Compare
3dfcb05 to
bb7bd21
Compare
|
Hi @aolemila , thank you for the PR! Can you rebase to head and resolve the conflicts? We've removed the sglang script so it should merge into our main script. And please squash the commits. |
wang2yn84
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for your PR! Left some comments.
|
Hi @wang2yn84 , thanks for your reply. I will rebase the main and modify codes according to your advice. |
|
I am rerunning scripts and fix new problems I meet. |
bb7bd21 to
e32615c
Compare
e32615c to
3f755c5
Compare
3f755c5 to
874bfbc
Compare
|
Hi, @wang2yn84 . I have updated codes according to your suggestions. In addition to modifications, I have passed three test cases. You can see more details in PR descriptions.
|
874bfbc to
fd80b54
Compare
tunix/generate/sglang_jax_sampler.py
Outdated
| new_model_state_leaves, _ = jax.tree_util.tree_flatten(new_state) | ||
| self._model_runner.model_state_leaves = new_model_state_leaves | ||
|
|
||
| flatten_src_to_tgt_module_name = os.getenv(VERIFY_UPDATE_PARAMS_KEY, None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This part of validation should belong to the test instead of the production code?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed here: commit.
| @@ -0,0 +1,371 @@ | |||
| """ | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I didn't look into the details in the last round of review. Seems this integration test is quite heavy, using 3B model to run the whole GRPO workflow. Such test better go to nightly regression. In CI, can we have some lightweight validation such just test update_param?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok. I add python scripts/grpo_demo_llama3_qwen2.py --num-batches 2 --num-test-batches 1 --root-dir=/home/gcpuser/aolemila --rollout-engine sglang_jax --enable-lora --lora-target-modules all in tpu-nightly-regression.yml to run LoRA case. For tests/generate/sglang_jax_lora_test.py in tpu-tests.yml, I will simplify it and make it more lightweight.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed here: commit.
| return text.split("####")[1].strip() | ||
|
|
||
|
|
||
| def download_kaggle_dataset(target_dir="./data/gsm8k"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we leverage the existing API? We have dataset loading and get lora model APIs. No need to recreate these functions again. If the existing API is not sufficient, say there is no other dataset support, can you help improve the existing API maybe in a separate PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok. These codes are based on old version scripts/grpo_demo_llama3_qwen2.py, and maybe they are outdated. I will follow the latest scripts/grpo_demo_llama3_qwen2.py to use recommended APIs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not used in simpified version. Fixed here: commit.
| # List of batch sizes buckets for jax jit | ||
| rollout_sglang_jax_precompile_bs_paddings: Optional[List] = None | ||
| # Whether to use lora | ||
| rollout_sglang_jax_enable_static_lora: bool = False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this suppose to be True? IIUC, the way Tunix uses Lora is static, cuz we don't require to select from multiple lora and change on the fly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is another case that you may not use LoRA, so setting it to True ensures that you know you are using LoRA. And SGLangJax will replace the base_layer and initialize the zero buffer if you enable_static_lora. There are a few differences compared with disabling static lora.
tunix/generate/sglang_jax_sampler.py
Outdated
| if ( | ||
| mappings is None | ||
| or not enable_static_lora | ||
| or lora_target_modules is None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"not lora_target_modules" should have the same effect as "or lora_target_modules is None or len(lora_target_modules) == 0"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks.
tunix/generate/sglang_jax_sampler.py
Outdated
| self.engine = Engine(**self.args) | ||
|
|
||
| self.mappings = config.mapping_config.to_hf_mappings | ||
| self.to_hf_key_mappings = config.mapping_config.to_hf_mappings |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
redundant
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.
| args["enable_deterministic_sampling"] = True | ||
| if config.init_with_random_weights: | ||
| args["load_format"] = "dummy" | ||
| args["disable_radix_cache"] = config.disable_radix_cache |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider put checkers into a separate function
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
tunix/generate/sglang_jax_sampler.py
Outdated
| ) | ||
|
|
||
|
|
||
| def update_hf_key_mappings_with_lora( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably need to move this function to top of the file, other our internal might complain about not be able to find the function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
fd80b54 to
401c62f
Compare
401c62f to
48c01da
Compare

Resolves #825.
scripts/grpo_demo_llama3_qwen2.pyto run LoRA.sglang_jax_lora_test.pyto ensureupdate_paramsworks, and put it intotpu-tests.yml.verify_update_paramswill be executed whenVERIFY_UPDATE_PARAMS_KEYis configured.Test1: Run verification of
update_paramsJAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python3 tests/generate/sglang_jax_lora_test.py.Test2: Run
scripts/grpo_demo_llama3_qwen2.pywithout LoRAJAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python3 scripts/grpo_demo_llama3_qwen2.py --num-batches 2 --num-test-batches 1 --root-dir=/home/gcpuser/aolemila --rollout-engine sglang_jax.Test3: Run
scripts/grpo_demo_llama3_qwen2.pywith LoRAJAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python3 scripts/grpo_demo_llama3_qwen2.py --num-batches 2 --num-test-batches 1 --root-dir=/home/gcpuser/aolemila --rollout-engine sglang_jax --enable-lora --lora-target-modules all.Reference
Colab Notebook
Checklist